#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @FileName : gaze_onnx_dete
# @Author   : LesterXu
# @Email    : cy163xgj@yeah.net
# @Time     : 2024/6/18 14:42


import os
import cv2
import onnx
import numpy as np
import onnxruntime


# ---------------------------------------------------------#
#   Below:
#   ONNX helper functions
#   ||||||||
# ---------------------------------------------------------#
def get_input_name(onnx_session):
    """Return the names of all model inputs, in the order the session reports them.

    :param onnx_session: an ``onnxruntime.InferenceSession`` (or any object whose
        ``get_inputs()`` yields items exposing a ``name`` attribute).
    :return: a new list of input-name strings (empty if the model has no inputs).
    """
    return [model_input.name for model_input in onnx_session.get_inputs()]


def get_input_feed(onnx_session, img_tensor):
    """Build the feed dict passed to ``InferenceSession.run``.

    Every input name reported by the session is bound to the same
    ``img_tensor``. For debugging, the full list of input names is printed
    first, followed by one line per name.

    :param onnx_session: session whose inputs should be fed.
    :param img_tensor: the array to supply for each input.
    :return: dict mapping input name -> ``img_tensor``.
    """
    names = get_input_name(onnx_session)
    print("input_name:", names)
    for input_key in names:
        print("name:", input_key)
    return dict.fromkeys(names, img_tensor)


# Labels indexed by position in the model's output score vector
# (the prediction is GAZE_CLASSES[argmax(scores)]).
GAZE_CLASSES = ('front', 'left_mirror', 'left_window', 'middle_mirror', 'right_mirror', 'right_window')

# File extensions (compared case-insensitively) that are treated as images.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')


def preprocess_image(img, size=64, mean=0.5, std=0.5):
    """Convert an image from ``cv2.imread`` into a normalized model input tensor.

    The image is resized to ``size`` x ``size``, converted HWC -> CHW,
    scaled to [0, 1] and normalized with ``(x - mean) / std``. A leading batch
    dimension is then added.

    :param img: HxWx3 array as returned by ``cv2.imread`` (BGR channel order).
    :param size: target width and height in pixels.
    :param mean: normalization mean applied to every channel.
    :param std: normalization std applied to every channel.
    :return: float32 array of shape (1, 3, size, size).
    """
    img = cv2.resize(img, (size, size))
    # NOTE(review): channels are fed in BGR order (cv2.imread's native order).
    # An RGB2BGR swap followed by a channel-axis reversal cancel out, so this
    # plain transpose keeps the exact tensor the script has always produced.
    # If the model was trained on RGB input, swap channels here (e.g.
    # cv2.COLOR_BGR2RGB) -- confirm against the training pipeline.
    chw = img.transpose((2, 0, 1)).astype(np.float32)  # HWC -> CHW
    chw /= 255.0  # 0..255 -> 0.0..1.0
    normalized = (chw - mean) / std
    return normalized[None]  # add batch dim: (3, H, W) -> (1, 3, H, W)


def main(imgs_path='.', onnx_path='best_model.onnx'):
    """Classify every image file in ``imgs_path`` with the ONNX model and print the label.

    :param imgs_path: directory to scan (non-recursive) for .jpg/.jpeg/.png files.
    :param onnx_path: path to the ONNX model file.
    """
    session = onnxruntime.InferenceSession(onnx_path)
    # Only the first model output is used; its name does not change per image.
    output_name = session.get_outputs()[0].name

    # Sorted for a deterministic processing/print order.
    for img_name in sorted(os.listdir(imgs_path)):
        if not img_name.lower().endswith(IMAGE_EXTENSIONS):
            continue
        img_path = os.path.join(imgs_path, img_name)
        if not os.path.isfile(img_path):
            continue
        print(img_path)

        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread returns None (rather than raising) for unreadable files.
            print("failed to read image, skipping:", img_path)
            continue

        input_feed = get_input_feed(session, preprocess_image(img))
        outputs = session.run([output_name], input_feed)
        scores = outputs[0][0]  # first output, first (only) batch item
        print(GAZE_CLASSES[int(np.argmax(scores))])


if __name__ == '__main__':
    main()


